import torch
import yaml
from RawNet.model import RawNet
from RawNet40000.model import RawNet as RawNet40000
from RawNet48000.model import RawNet as RawNet48000

# Torch device string shared by every loader below: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def raw4(*,
         model_path=r'F:\Attack-Transferability-On-Synthetic-Detection\ALL_MODEL\RAWNET40000\model_30_32_1024_lr0.0003_norm\epoch_25_99.0_98.1.pth',
         configyaml=r"F:\Attack-Transferability-On-Synthetic-Detection\RawNet40000\model_config_RawNet.yaml"):
    """Build a ``RawNet40000`` model and load trained weights into it.

    The model is constructed from the ``'model'`` mapping of a YAML config
    file, moved to the module-level ``device``, and populated with the state
    dict stored in the checkpoint at *model_path*.

    Args:
        model_path: Checkpoint readable by ``torch.load`` that holds a state
            dict for ``RawNet40000``. Defaults to the original hard-coded path.
        configyaml: YAML config whose ``'model'`` entry is passed to the
            ``RawNet40000`` constructor. Defaults to the original hard-coded
            path.

    Returns:
        The ``RawNet40000`` instance on ``device`` with weights loaded. This
        function does not call ``model.eval()``; callers doing inference
        should switch modes themselves if needed.

    Raises:
        OSError: If the config file cannot be opened.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys/shapes
            do not match the model.
    """
    # YAML is defined over UTF-8; don't rely on the platform default encoding
    # (cp1252 on Windows). safe_load only builds plain data types, which is all
    # a model config needs, and avoids constructing arbitrary Python objects.
    with open(configyaml, 'r', encoding='utf-8') as f_yaml:
        model_config = yaml.safe_load(f_yaml)
    model = RawNet40000(model_config['model'], device).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('Model loaded : {}'.format(model_path))
    return model

def raw48(*,
          model_path=r"F:\Attack-Transferability-On-Synthetic-Detection\ALL_MODEL\RAWNET48000\model_30_32_1024_lr0.0003_norm\epoch_15_99.2_99.2.pth",
          configyaml=r"F:\Attack-Transferability-On-Synthetic-Detection\RawNet48000\model_config_RawNet.yaml"):
    """Build a ``RawNet48000`` model and load trained weights into it.

    The model is constructed from the ``'model'`` mapping of a YAML config
    file, moved to the module-level ``device``, and populated with the state
    dict stored in the checkpoint at *model_path*.

    Args:
        model_path: Checkpoint readable by ``torch.load`` that holds a state
            dict for ``RawNet48000``. Defaults to the original hard-coded path.
        configyaml: YAML config whose ``'model'`` entry is passed to the
            ``RawNet48000`` constructor. Defaults to the original hard-coded
            path.

    Returns:
        The ``RawNet48000`` instance on ``device`` with weights loaded. This
        function does not call ``model.eval()``; callers doing inference
        should switch modes themselves if needed.

    Raises:
        OSError: If the config file cannot be opened.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys/shapes
            do not match the model.
    """
    # YAML is defined over UTF-8; don't rely on the platform default encoding
    # (cp1252 on Windows). safe_load only builds plain data types, which is all
    # a model config needs, and avoids constructing arbitrary Python objects.
    with open(configyaml, 'r', encoding='utf-8') as f_yaml:
        model_config = yaml.safe_load(f_yaml)
    model = RawNet48000(model_config['model'], device).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('Model loaded : {}'.format(model_path))
    return model

def raw646(*,
           model_path=r"F:\Attack-Transferability-On-Synthetic-Detection\ALL_MODEL\RAWNET\model_30_32_1024_lr0.0003_norm\epoch_29_99.7_99.5.pth",
           configyaml=r"F:\Attack-Transferability-On-Synthetic-Detection\RawNet\model_config_RawNet.yaml"):
    """Build a base ``RawNet`` model and load trained weights into it.

    The model is constructed from the ``'model'`` mapping of a YAML config
    file, moved to the module-level ``device``, and populated with the state
    dict stored in the checkpoint at *model_path*.

    Args:
        model_path: Checkpoint readable by ``torch.load`` that holds a state
            dict for ``RawNet``. Defaults to the original hard-coded path.
        configyaml: YAML config whose ``'model'`` entry is passed to the
            ``RawNet`` constructor. Defaults to the original hard-coded path.

    Returns:
        The ``RawNet`` instance on ``device`` with weights loaded. This
        function does not call ``model.eval()``; callers doing inference
        should switch modes themselves if needed.

    Raises:
        OSError: If the config file cannot be opened.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys/shapes
            do not match the model.
    """
    # YAML is defined over UTF-8; don't rely on the platform default encoding
    # (cp1252 on Windows). safe_load only builds plain data types, which is all
    # a model config needs, and avoids constructing arbitrary Python objects.
    with open(configyaml, 'r', encoding='utf-8') as f_yaml:
        model_config = yaml.safe_load(f_yaml)
    model = RawNet(model_config['model'], device).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('Model loaded : {}'.format(model_path))
    return model